{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4whv6jgJwagT"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/generation/llm-field-guide/mpt/mpt-7b-huggingface-langchain.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/generation/llm-field-guide/mpt/mpt-7b-huggingface-langchain.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JPdQvYmlWmNc"
      },
      "source": [
        "# GPT-J-6B in Hugging Face and LangChain\n",
        "\n",
        "In this notebook we'll explore how we can use the open source **GPT-J-6B** model in both Hugging Face transformers and LangChain.\n",
        "\n",
        "---\n",
        "\n",
        "\ud83d\udea8 _Note that running this on CPU is practically impossible. It will take a very long time. If running on Google Colab you go to **Runtime > Change runtime type > Hardware accelerator > GPU > GPU type > T4** (ideally can go for better GPU for faster speeds like V100 or A100)._\n",
        "\n",
        "---\n",
        "\n",
        "We start by doing a `pip install` of all required libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K_fRq0BSGMBk",
        "outputId": "a340230c-1a56-4ff5-af48-aa3c21d9e8d6"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m0.0/7.5 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m0.1/7.5 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:04\u001b[0m\r\u001b[2K     \u001b[91m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m\u001b[91m\u2578\u001b[0m\u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m3.3/7.5 MB\u001b[0m \u001b[31m48.9 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[91m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m\u001b[91m\u2578\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m80.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m58.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m251.2/251.2 kB\u001b[0m \u001b[31m28.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m42.2/42.2 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m79.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m167.0/167.0 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m29.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m75.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m81.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "!pip install -qU transformers accelerate einops langchain xformers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VHQwEeW9Zps2"
      },
      "source": [
        "## Initializing the Hugging Face Pipeline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mElf068NXout"
      },
      "source": [
        "The first thing we need to do is initialize a `text-generation` pipeline with Hugging Face transformers. The Pipeline requires three things that we must initialize first, those are:\n",
        "\n",
        "* A LLM, in this case it will be `EleutherAI/gpt-j-6b`.\n",
        "\n",
        "* The respective tokenizer for the model.\n",
        "\n",
        "* A stopping criteria object.\n",
        "\n",
        "We'll explain these as we get to them, let's begin with our model.\n",
        "\n",
        "We initialize the model and move it to our CUDA-enabled GPU. Using Colab this can take 5-10 minutes to download and initialize the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153,
          "referenced_widgets": [
            "da735ac1f30c4eee8292f717725319e5",
            "af656d19845d4562b40afcecab012e40",
            "56cc7ff9a1aa415686b9f551b612ea64",
            "efd837f56145437fa014bc88ff2d9e8f",
            "619e30b0f4ba4421bd66d37c681014f9",
            "947bc5c578b74c839e4a7792cd3069a8",
            "2937d43aec2b48ae98c99cb83873f3c9",
            "1ca5276b1f0a4809be01f77d1865c3db",
            "84493fb40e274330be9fcba1ddba0cce",
            "8f9e87cefcd44f588684c02f80a22357",
            "3531806aa9fd4e69bea7f24036ce7610",
            "c5e5b41a0b264281810c4248e066110d",
            "acc5f209424a42b896f14a19a9a3eb2a",
            "db1370da98294a4dbec1cf8b3ce498cf",
            "6cee774e78ef4ff4b65062e2b763e0a6",
            "aed34383f9da4dce9e69cf318c185ed6",
            "a21069a60b1a474f8cc6e72f0259a94a",
            "f146c5b0687246ea801aa1337f5d9010",
            "3c3c6028cd4d43e9975972b10dcdbbbd",
            "388524ce63e2448fa3d9a67597232329",
            "650e92bf241b47329b48e0b2610dd680",
            "c0b606ba17e841b3aa0953b7189d3785"
          ]
        },
        "id": "ikzdi_uMI7B-",
        "outputId": "b9462ea0-982d-4b25-cf11-cb6e31bbb2fd"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)lve/main/config.json:   0%|          | 0.00/930 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "da735ac1f30c4eee8292f717725319e5"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c5e5b41a0b264281810c4248e066110d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model loaded on cuda:0\n"
          ]
        }
      ],
      "source": [
        "from torch import cuda, bfloat16\n",
        "import transformers\n",
        "\n",
        "device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'\n",
        "model_name = 'EleutherAI/gpt-j-6b'\n",
        "\n",
        "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    trust_remote_code=True,\n",
        "    torch_dtype=bfloat16\n",
        ")\n",
        "model.eval()\n",
        "model.to(device)\n",
        "print(f\"Model loaded on {device}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JzX9LqWSX9ot"
      },
      "source": [
        "The pipeline requires a tokenizer which handles the translation of human readable plaintext to LLM readable token IDs. The GPT-J-6B model was trained using the `EleutherAI/gpt-j-6b` tokenizer, which we initialize like so:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 209,
          "referenced_widgets": [
            "6cc5b1e3878b45e891b6a871b16e2bc2",
            "fb4036ead2374f6c8ec93710a570ba67",
            "edf745f1219244468f6ffb91fbe99de3",
            "8e1371bbd5534541945bacf7e85f31ef",
            "b3f2104ab1b24612a2e664acd08c11c7",
            "a41fabc965d44ad597a750a746ae8bc0",
            "08cf8a59e9ca43959259a727af53c24d",
            "4c955f75635448679ecc559011531a5a",
            "74399f28c3d74b4da606421d6a34b216",
            "2ec67b62e29b4dc8830e30b406774c00",
            "65e44b1e6da94380b73efc2a935d2bde",
            "782ca761241b441e872106302d19c5c8",
            "733e18ba75114171aa33cfad35963415",
            "8a99633d4d1e4e45a3e4160b9a2d31e4",
            "218db3111be642a08eb9dd53a50aa2e2",
            "f2896246045040c2ab492897a0c5a6be",
            "10d7cc2bd7f24aa3b06b4b27e48bb759",
            "6128c9e615954ba39fa9e13984436f4d",
            "83ba792c129f4c83a9f3c40c48ca7484",
            "0589af388855488cba5b17574c8a6478",
            "8ef098a55500429ba33957bd29668144",
            "655ce4dddf8c42f5b1bf8befba173b17",
            "d46b466943134ca99d44436e4ec3cfb9",
            "343a7545529f4cca970f1a2d92d22441",
            "c3eaf404b23f4e7dad4023a2b95d700c",
            "fcef9ec444c24b7cbfd46b55bd17ff68",
            "d2a139c61bd545fa94bd429672122ab1",
            "8613291e301c4ca2b34e1163a32a9628",
            "891dfdc7ff9c45b4a04498de9b9d7529",
            "3935c2455d3248048b94eb3cfe3f0477",
            "1a79dfcc678344f5a96715e9a684d253",
            "c1bbe26689b9400f95d96483664b7a0c",
            "82248595dae54d72a97c7bd009f2e1a7",
            "4918bd35af604b80ade80cb6f8077c13",
            "b5b88888c8a94c7bb2e79948d3739845",
            "b9c5eb4a982e444e98445644c2138f83",
            "4cb9bf77ed384e698bbdc3f7e0ecaeb6",
            "91d24833e2d8430f90f78186b3d30f02",
            "9fcd9c7bccb64cc9aa62af17891d33e1",
            "46dc8e7c85ab4f9ba603ac6a33368a31",
            "c310ccde98ec495eadf10256ce2567dc",
            "5d5fa744ac1f4facb1bcd6ccb2e96381",
            "46eabdeff8504bef941f1e7fc51dc2b1",
            "ba89af32fbff4f2498fa82195156d740",
            "1a5b301d45784e2eb4725b6096e0a213",
            "28c3cfcbab714ba396df2f1d808c449b",
            "1e06c22634464b2f89302f30d1ebf5cc",
            "80fcbbbe77994a51980ec901c5c788ac",
            "6e0238deae9145098dbcb880f859a2b7",
            "b4c8cf756c304d2d97eef46972a752b6",
            "ee527de37006411f96fe1c5479caa20f",
            "b5a39ec896c846c4af2c61d85f71dc70",
            "7613afd09bfd423bb4bd8818bfbde289",
            "b8d5e569c2d2468589376c305611935b",
            "1113a43553ee4546ad5cd89775bd026f",
            "e6f5d27b4ac843c5a543999bc02f9892",
            "674a6272d6604811b9514fd84c4a421c",
            "be736af0631645e18aa50bd574fad018",
            "0f1b097d11d046b5b4b0b734168fb0b5",
            "ed4d4181f68f44a1a33741d8402f22e2",
            "7c20001ece1c4bdcb1be8fc62a2da79a",
            "f5a0600d1cb046ed816cf181dec24b14",
            "2005040330924c3dacd4b846cd2096c2",
            "22110b80d7ce4fba886aec1cc58b34e7",
            "98cb870caa314eda88ccdbc148843181",
            "6eca89106cad4354adaefae24cb8e278"
          ]
        },
        "id": "v0iPv1GDGxgT",
        "outputId": "2ab64f80-ca1d-4608-9f9e-5dc2a3cc8f02"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)okenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "6cc5b1e3878b45e891b6a871b16e2bc2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "782ca761241b441e872106302d19c5c8"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d46b466943134ca99d44436e4ec3cfb9"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)/main/tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "4918bd35af604b80ade80cb6f8077c13"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)in/added_tokens.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "1a5b301d45784e2eb4725b6096e0a213"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (\u2026)cial_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e6f5d27b4ac843c5a543999bc02f9892"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XL7G9Sr3uxdz"
      },
      "source": [
        "Finally we need to define the _stopping criteria_ of the model. The stopping criteria allows us to specify *when* the model should stop generating text. If we don't provide a stopping criteria the model just goes on a bit of a tangent after answering the initial question."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "UG3R0LBQevQW"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import StoppingCriteria, StoppingCriteriaList\n",
        "\n",
        "# gpt-j-6b is trained to add \"<|endoftext|>\" at the end of generations\n",
        "stop_token_ids = tokenizer.convert_tokens_to_ids([\"<|endoftext|>\"])\n",
        "\n",
        "# define custom stopping criteria object\n",
        "class StopOnTokens(StoppingCriteria):\n",
        "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n",
        "        for stop_id in stop_token_ids:\n",
        "            if input_ids[0][-1] == stop_id:\n",
        "                return True\n",
        "        return False\n",
        "\n",
        "stopping_criteria = StoppingCriteriaList([StopOnTokens()])"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "stop_token_ids"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hK7dmcEVVNqD",
        "outputId": "bfcc438d-d48d-42b3-fc44-b5122f101d26"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[50256]"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bNysQFtPoaj7"
      },
      "source": [
        "Now we're ready to initialize the HF pipeline. There are a few additional parameters that we must define here. Comments explaining these have been included in the code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "qAYXi8ayKusU"
      },
      "outputs": [],
      "source": [
        "generate_text = transformers.pipeline(\n",
        "    model=model, tokenizer=tokenizer,\n",
        "    return_full_text=True,  # langchain expects the full text\n",
        "    task='text-generation',\n",
        "    device=device,\n",
        "    # we pass model parameters here too\n",
        "    stopping_criteria=stopping_criteria,  # without this model will ramble\n",
        "    temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max\n",
        "    top_p=0.15,  # select from top tokens whose probability add up to 15%\n",
        "    top_k=0,  # select from top 0 tokens (because zero, relies on top_p)\n",
        "    max_new_tokens=150,  # mex number of tokens to generate in the output\n",
        "    repetition_penalty=1.1  # without this output begins repeating\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8DG1WNTnJF1o"
      },
      "source": [
        "Confirm this is working:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lhFgmMr0JHUF",
        "outputId": "722b1a15-2b79-4a9c-e752-f9a1ed142138"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Explain to me the difference between nuclear fission and fusion.\n",
            "\n",
            "A:\n",
            "\n",
            "Fusion is a process by which two nuclei fuse together, releasing energy in the form of heat and light.  Fusion is not an end state; it's a transient state that lasts for only a few microseconds.  It's also very difficult to achieve.  The most common method of achieving fusion is through the use of a neutron bomb.  A neutron bomb uses a fission weapon as its primary payload, but then adds a secondary payload that contains a large amount of deuterium (a heavy isotope of hydrogen) and tritium (another heavy isotope of hydrogen).  When these two elements are combined, they undergo fusion, releasing a huge amount of energy.  This is why neutron\n"
          ]
        }
      ],
      "source": [
        "res = generate_text(\"Explain to me the difference between nuclear fission and fusion.\")\n",
        "print(res[0][\"generated_text\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0N3W3cj3Re1K"
      },
      "source": [
        "Now to implement this in LangChain"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "-8RxQYwHRg0N"
      },
      "outputs": [],
      "source": [
        "from langchain import PromptTemplate, LLMChain\n",
        "from langchain.llms import HuggingFacePipeline\n",
        "\n",
        "# template for an instruction with no input\n",
        "prompt = PromptTemplate(\n",
        "    input_variables=[\"instruction\"],\n",
        "    template=\"{instruction}\"\n",
        ")\n",
        "\n",
        "llm = HuggingFacePipeline(pipeline=generate_text)\n",
        "\n",
        "llm_chain = LLMChain(llm=llm, prompt=prompt)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "208tHnunRngH",
        "outputId": "6d5b703e-13dd-4506-82a4-2fabb8e6a982"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "A:\n",
            "\n",
            "Fusion is a process by which two nuclei fuse together, releasing energy in the form of heat and light.  Fusion is not an end state; it's a transient state that lasts for only a few microseconds.  It's also very difficult to achieve.  The most common method of achieving fusion is through the use of a neutron bomb.  A neutron bomb uses a fission weapon as its primary payload, but then adds a secondary payload that contains a large amount of deuterium (a heavy isotope of hydrogen) and tritium (another heavy isotope of hydrogen).  When these two elements are combined, they undergo fusion, releasing a huge amount of energy.  This is why neutron\n"
          ]
        }
      ],
      "source": [
        "print(llm_chain.predict(\n",
        "    instruction=\"Explain to me the difference between nuclear fission and fusion.\"\n",
        ").lstrip())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tv0KxJLvsIa"
      },
      "source": [
        "We still get the same output as we're not really doing anything differently here, but we have now added GPT-J-6B-instruct to the LangChain library. Using this we can now begin using LangChain's advanced agent tooling, chains, etc, with GPT-J-6B."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}